// file:arch/x86/arch/mm/mem_pool.c
// author:jiangxinpeng
// time:2021.1.9
// copyright:(C) 2020-2050 by jiangxinpeng, All rights reserved.

#include <lib/type.h>
#include <lib/bitmap.h>
#include <lib/list.h>
#include <lib/math.h>
#include <arch/pymem.h>
#include <arch/page.h>
#include <arch/bootmem.h>
#include <arch/mem_pool.h>

// Table of physical memory ranges managed by this pool, indexed by range number.
// It has static storage duration, so every entry starts zeroed (nodetable == NULL,
// startbase == endbase == 0) until MemRangeInit fills it in.
mem_range_t range[MEM_RANGE_COUNT];

// init a mem range
int MemRangeInit(uint8_t index, uint32_t start, uint32_t end)
{
        int i;
        mem_section_t *section;
        uint32_t section_size;  // section size
        uint32_t section_count; // current section nodecount
        uint32_t section_less;  // section less count
        uint8_t section_index;  // section index
        mem_node_t *nodebase;   // nodebase
        uint32_t freepage = 0;  // free page

        // check index if is true
        if (index > MEM_RANGE_COUNT)
                return -1;

        // init mem range
        (range + index)->startbase = start;
        (range + index)->endbase = end;
        (range + index)->size = end - start;
        (range + index)->pages = (range + index)->size / PAGE_SIZE;
        (range + index)->nodetable = (mem_node_t *)BootMemAlloc(MEM_NODE_SIZE * (range + index)->pages);
        if (!(range + index)->nodetable)
        {
                KPrint("[memrange] alloc for range %d nodetable failed!\n", index);
                return -1;
        }

        // init mem section
        for (i = 0; i < MEM_SECTION_MAX; i++)
        {
                // init section size
                MemSectionInit((range + index)->section + i, powi(2, i));
        }

        // start init section
        nodebase = ((mem_range_t *)(range + index))->nodetable;
        section = ((mem_range_t *)(range + index))->section;

        section_index = MEM_SECTION_MAX - 1;
        section_size = powi(2, section_index);
        freepage = (range + index)->pages;

        // if free mem less
        while (freepage > 0 && section_index)
        {
                // figure section count
                section_count = freepage / section_size;
                section_less = freepage % section_size;

                // if present big section
                if (section_count > 0)
                {
                        // KPrint("[pymem] set section idx %d section count %d freepage %d nodebase %x\n", section_index, section_count, freepage, nodebase);
                        MemSectionSet(section + section_index, section_count, nodebase);
                        // cut section alloc size
                        freepage -= (section_count * section_size);
                        // next nodebase
                        nodebase += (section_count * section_size);
                }
                else
                {
                        if (section_less > 0)
                        {
                                // KPrint("[pymem] set section idx %d to idx 0 section count %d freepage %d nodebase %x\n", section_index, section_less, freepage, nodebase);
                                // less section count put to section 0
                                MemSectionSet(section + 0, section_less, nodebase);
                                break;
                        }
                }

                // next section
                section_index--;
                section_size >>= 1;
        }

        KPrint("[pymem] init mem range %x-%x size %u(%u MB) pages %d node base %x\n", (range + index)->startbase, (range + index)->endbase, (range + index)->size, (range+index)->size/MB,(range + index)->pages, (range + index)->nodetable);
        return 0;
}

// Reset `section` to an empty free list of blocks, each `section_size` pages long.
//
// section: section descriptor to initialise.
// section_size: number of pages per block in this section.
// Returns 0 on success, or -1 if `section` is NULL.
int MemSectionInit(mem_section_t *section, uint32_t section_size)
{
        if (!section)
                return -1;

        section->nodecount = 0;
        section->sectionsize = section_size;
        list_init(&section->free_list_head);

        // the original fell off the end of a non-void function (UB if the
        // caller uses the result); report success explicitly
        return 0;
}

// Populate `section` with `sectioncount` free blocks laid out back to back in the
// node table, starting at the node whose address is `nodebase`.
//
// The head node of block k sits k * sectionsize nodes past `nodebase` (one node
// per page). For each block:
//   - its head node is reset with MemNodeInit(node, 0, 0);
//   - it is appended to the section's free list;
//   - it is tagged with MEM_NODE_SECTION_SET.
// section->nodecount is overwritten with `sectioncount`. The free list is not
// cleared first, so this is meant to be called once per freshly initialised section.
void MemSectionSet(mem_section_t *section, uint32_t sectioncount, uint32_t nodebase)
{
        mem_node_t *node = (mem_node_t *)nodebase;
        int idx;

        section->nodecount = sectioncount;

        for (idx = 0; idx < section->nodecount; idx++)
        {
                MemNodeInit(node, 0, 0);
                list_add_tail(&node->list, &section->free_list_head);
                MEM_NODE_SECTION_SET(node);

                // step over this block's nodes to the next block's head node
                node += section->sectionsize;
        }
}

// Reset `node` to a blank state.
//
// The node's list entry is re-initialised, `ref` and `count` are stored, and
// flags, usedcount and the cache/group back-pointers are cleared.
// node->section is left untouched.
void MemNodeInit(mem_node_t *node, uint32_t ref, uint32_t size)
{
        list_init(&node->list);

        node->ref = ref;
        node->count = size;

        node->usedcount = 0;
        node->mem_group = NULL;
        node->mem_cache = NULL;
        node->flags = 0;
}

// Find the range whose node table contains `node`.
// A range's node table holds one entry per page, i.e. `pages` entries.
// Returns that range, or NULL if `node` is in no range's table.
mem_range_t *GetMemRangeByNode(mem_node_t *node)
{
        mem_range_t *cur;

        for (cur = range; cur != range + MEM_RANGE_COUNT; cur++)
        {
                mem_node_t *first = cur->nodetable;
                mem_node_t *last = first + cur->pages; // one past the final node

                if (first <= node && node < last)
                        return cur;
        }
        return NULL;
}

// Translate a physical address to the mem_node_t of the page that contains it.
//
// pybase: physical address; it need not be page aligned.
// Returns the node pointer, or NULL when `pybase` lies in no registered range.
mem_node_t *Pybase2Memnode(uint32_t pybase)
{
        // named `owner` so it does not shadow the global `range` table
        mem_range_t *owner = GetMemRangeByPybase(pybase);
        uint32_t index;

        if (owner == NULL)
                return NULL;

        // one node per page, indexed by the page offset from the range start;
        // keep the result as a pointer instead of round-tripping through uint32_t
        index = (pybase - owner->startbase) / PAGE_SIZE;
        return owner->nodetable + index;
}

// Translate a node pointer back to the physical address of its page.
// Returns 0 when `node` is in no range's node table.
uint32_t Memnode2Pybase(mem_node_t *node)
{
        mem_range_t *owner = GetMemRangeByNode(node);

        if (owner == NULL)
                return 0;

        // node's position in the table is its page number within the range
        return owner->startbase + (uint32_t)(node - owner->nodetable) * PAGE_SIZE;
}

mem_range_t *GetMemRangeByPybase(uint32_t pybase)
{
        int i;
        // find in all mem range
        for (i = 0; i < MEM_RANGE_COUNT; i++)
        {
                if (pybase >= (range + i)->startbase && pybase < (range + i)->endbase)
                {
                        return (range + i);
                }
        }
        return NULL;
}

// Ensure `section` of `range` has at least one free block by splitting larger blocks.
//
// While `section` is empty:
//   1. Find the nearest larger section that has a free block.
//   2. Take its first block off that section's free list.
//   3. Insert the two halves at the front of the next smaller section's free list.
// Repeating this walks the split down one level at a time until `section` is
// non-empty. Both halves are reset with ref = 1 and count = half the parent
// block size.
// NOTE(review): MemSectionSet and FreeMemNode reset free nodes with ref = 0 and
// count = 0; confirm that ref = 1 on split halves is intended.
//
// Returns 0 if `section` already had or now has a free block, or if either
// argument is NULL. Returns -1 if no larger section has anything left to split.
int MemRangeSplitSection(mem_range_t *range, mem_section_t *section)
{
        mem_section_t *top;

        if (!range || !section)
                return 0;

        top = range->section + (MEM_SECTION_MAX - 1);

        while (list_empty(&section->free_list_head))
        {
                mem_section_t *donor;
                mem_node_t *lower;
                mem_node_t *upper;

                // nearest larger section that still holds a free block
                for (donor = section + 1; donor <= top; donor++)
                {
                        if (!list_empty(&donor->free_list_head))
                                break;
                }

                if (donor > top)
                {
                        KPrint("[pymem] no free section free left\n");
                        return -1;
                }

                lower = list_first_owner(&donor->free_list_head, mem_node_t, list);
                list_del(&lower->list);
                donor->nodecount--;

                // block nodes are contiguous (one per page), so the upper half's
                // head is `half` nodes past the lower half's head
                MemNodeInit(lower, 1, donor->sectionsize / 2);
                upper = lower + lower->count;
                MemNodeInit(upper, 1, lower->count);

                donor--; // both halves belong to the next smaller section
                list_add_after(&lower->list, &donor->free_list_head);
                list_add_after(&upper->list, &donor->free_list_head);
                lower->section = donor;
                upper->section = donor;
                donor->nodecount += 2;
        }

        return 0;
}

// alloc a free mem node
void *AllocMemNode(uint8_t range_index, uint32_t count)
{
        int i;
        mem_section_t *section;
        mem_node_t *node;
        // point to targe range
        mem_range_t *memrange = &range[range_index];
        // KPrint("[pymem] try to request page on range %x-%x range idx %d\n", memrange->startbase, memrange->endbase, range_index);

        // page is large
        if (count > MEM_SECTION_SIZE_MAX)
        {
                KPrint("[pymem] %s: page %d too large!\n", __func__, count);
                return NULL;
        }

        // figure mem size
        for (i = 0; i < MEM_SECTION_MAX; i++)
        {
                section = memrange->section + i;
                // if is large section
                if (section->sectionsize >= count)
                        break;
        }

        // check if have free node
        if (list_empty(&(section->free_list_head)))
        {
                // if section also large size
                if (section->sectionsize == MEM_SECTION_SIZE_MAX)
                {
                        KPrint("[pymem] no free node on max section!\n");
                        // error
                        return NULL;
                }
                else
                {
                        if (MemRangeSplitSection(memrange, section) < 0)
                        {
                                KPrint("[pymem] split section failed!\n");
                                return NULL;
                        }
                }
        }
        // alloc node from list
        node = list_first_owner(&section->free_list_head, mem_node_t, list);
        // del node from list
        list_del_init(&node->list);
        // section count -1
        section->nodecount--;

        // set node section
        MEM_NODE_SECTION_SET(node);
        // set node info
        node->flags = 1;
        node->count = count;

        return (void *)Memnode2Pybase(node);
}

// Return the block whose head page is at physical address `pybase` to its
// section's free list.
//
// Returns -1 in these cases:
//   - the address maps to no node;
//   - the node has no section;
//   - the node is already on its section's free list.
// Otherwise the node is reset with MemNodeInit(node, 0, 0), appended to the
// free list, the section's nodecount is incremented, and 0 is returned.
int FreeMemNode(uint32_t pybase)
{
        mem_node_t *node;
        mem_section_t *section;

        node = Pybase2Memnode(pybase);
        if (node == NULL || node->section == NULL)
                return -1;

        section = node->section;

        // guard against a double free
        if (list_find(&node->list, &section->free_list_head))
                return -1;

        MemNodeInit(node, 0, 0);
        list_add_tail(&node->list, &section->free_list_head);
        section->nodecount++;

        return 0;
}

uint32_t _GetFreePageCountOnRange(uint32_t index)
{
        int i;
        uint32_t pagecount = 0;
        mem_range_t *range_t = range + index;
        mem_section_t *section;

        for (i = 0; i < MEM_SECTION_MAX; i++)
        {
                section = range_t->section + i;
                pagecount += (section->nodecount * section->sectionsize);
        }

        return pagecount;
}

// Total number of free pages across every memory range.
uint32_t _GetFreePageCount()
{
        uint32_t total = 0;
        uint32_t idx = MEM_RANGE_COUNT;

        // order is irrelevant: per-range counts are simply summed
        while (idx-- > 0)
                total += _GetFreePageCountOnRange(idx);

        return total;
}

// Print the total number of free pages, then the free page count of each range.
void MemPoolDump()
{
        // the original passed the function designator `GetFreePageCount`
        // (no call) to %d, printing an address instead of a page count
        KPrint("[mem pool] total free page: %u\n", _GetFreePageCount());
        for (int i = 0; i < MEM_RANGE_COUNT; i++)
        {
                KPrint("range %d free page: %u\n", i, _GetFreePageCountOnRange(i));
        }
}